/**
 * 2017年5月22日
 */
package cn.edu.bjtu.model.core;

import java.io.File;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

import cn.edu.bjtu.core.ClassificationPair;
import cn.edu.bjtu.core.Constant;
import cn.edu.bjtu.core.LoggerSupport;


/**
 * 类标名称管理器
 * 一个网络输出的结果应该是一个向量,(向量长度应该与这个类里面labels里面元素个数相相等,所以训练网络的时候需要更新这里面的数据.)
 * 这个向量代表了每个类的概率
 * 每个类名是什么
 * 就是用这个类来对应上
 * @author Alex
 *
 */
public class TextCategorizationManager extends LoggerSupport{
	
	private static TextCategorizationManager inst = null;
	public static TextCategorizationManager get(){
		if(inst == null){
			restore();
		}
		return inst;
	}
	
	private List<String> labels;
	
	public synchronized TextCategorizationManager setLabels(List<String> labels){
		this.labels = new ArrayList<String>(labels);
		return this;
	}
	
	public static synchronized TextCategorizationManager restore(){
		try{
			InputStream f = TextCategorizationManager.class.getClassLoader().getResourceAsStream("meta/"+Constant.Values.LabelsFileName);
			List<String> list = (List<String>)SerializationUtils.readObject(f);
			if(inst != null){
				inst.labels = list;
				return inst;
			}else{
				inst = new TextCategorizationManager();
				inst.labels = list;
			}
			return inst;
		}catch(Exception e){
			e.printStackTrace();
		}
		return inst;
	}
	/**
	 * 不涉及到全局变量及实例变量,所以是线程安全的.
	 * @param arr
	 * @return
	 */
	public ClassificationPair[] getDesc(INDArray arr){
		ClassificationPair res [] = new ClassificationPair[this.labels.size()];
		for(int i=0;i<res.length;i++){
			res[i] = new ClassificationPair(this.labels.get(i), arr.getDouble(i));
		}
		Arrays.<ClassificationPair>sort(res,(x,y)->{
			double t = y.getVal()-x.getVal();
			return t>0?1:(t<0?-1:0);
		});
		return res;
	}
	public boolean checkLabel(String label){
		return this.labels.contains(label);
	}

	public static void main(String args[]){
		List<String >lables = new ArrayList<>();
		File fs[] = new File("D:\\Share\\answer").listFiles();
		for(File f:fs){
			lables.add(f.getName());
		}
		Collections.sort(lables);
		//TextCategorizationManager.get().setLabels(lables).saveLabels();
		//System.out.println(TextCategorizationManager.get() ==TextCategorizationManager.restore());
		//System.out.println(TextCategorizationManager.restore().inst.labels);

//		System.out.println(TextCategorizationManager.get() == );
//		System.out.println(TextCategorizationManager.get().labels);
		
	}
	
}
